# -*- coding: utf-8 -*-

from .last_iterate import LastIterate
from .oracle_model_selection import OracleModelSelection

from .weight_avg import WeightAVG
from .beta_moving_avg import BetaMovingAVG
from .exp_moving_avg import ExpMovingAVG


def get_model_selection_method(selection_name):
    """Look up a model selection method by its registered name.

    Args:
        selection_name: One of ``"last_iterate"``,
            ``"oracle_model_selection"``, ``"weight_avg"``,
            ``"beta_moving_avg"`` or ``"exp_moving_avg"``.

    Returns:
        The object imported into this module for that name
        (``LastIterate``, ``OracleModelSelection``, ``WeightAVG``,
        ``BetaMovingAVG`` or ``ExpMovingAVG``). It is returned as-is;
        nothing is instantiated or called here.

    Raises:
        KeyError: If ``selection_name`` is not a registered name. The
            message lists the names that are accepted.
    """
    methods = {
        "last_iterate": LastIterate,
        "oracle_model_selection": OracleModelSelection,
        "weight_avg": WeightAVG,
        "beta_moving_avg": BetaMovingAVG,
        "exp_moving_avg": ExpMovingAVG,
    }
    # Keep raising KeyError (callers may already catch it), but say which
    # names are valid instead of surfacing only the offending key.
    if selection_name not in methods:
        raise KeyError(
            "Unknown model selection method {!r}; expected one of {}".format(
                selection_name, sorted(methods)
            )
        )
    return methods[selection_name]
